{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "nH85BOCo7YYk"
},
"source": [
"##### Copyright 2024 Google LLC."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "9tQNAByc7U9g"
},
"outputs": [],
"source": [
"# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "F7r2q0wS7bxf"
},
"source": [
"# Play with AI - Guess the word\n",
"\n",
"This cookbook illustrates how you can use the instruction-tuned Gemma 2 2B model as a chatbot to play a \"Guess the word\" game.\n",
"\n",
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Guess_the_word.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZHrL4tqs7mYK"
},
"source": [
"## Setup\n",
"\n",
"### Select the Colab runtime\n",
"To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n",
"\n",
"1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
"2. Select **Change runtime type**.\n",
"3. Under **Hardware accelerator**, select **T4 GPU**.\n",
"\n",
"\n",
"### Gemma setup on Kaggle\n",
"To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
"\n",
"* Get access to Gemma on kaggle.com.\n",
"* Select a Colab runtime with sufficient resources to run the Gemma 2 2B model.\n",
"* Generate and configure a Kaggle username and API key.\n",
"\n",
"After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pQEE8RoO75F-"
},
"source": [
"### Set environment variables\n",
"\n",
"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`, and choose the Keras backend with `KERAS_BACKEND`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "XsY2Ut7a76Wa"
},
"outputs": [],
"source": [
"import os\n",
"from google.colab import userdata\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"tensorflow\" or \"torch\".\n",
"\n",
"# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
"# vars as appropriate for your system.\n",
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n",
"os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")"
]
},
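{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you run this notebook outside Colab, the `from google.colab import userdata` import in the cell above fails with an `ImportError`. The next cell is a minimal sketch, not part of the original tutorial, of a hypothetical `get_secret` helper. It reads Colab Secrets when they are available and otherwise falls back to ordinary environment variables. When running locally, it assumes you have already exported `KAGGLE_USERNAME` and `KAGGLE_KEY` in your shell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"\n",
"def get_secret(name):\n",
"    \"\"\"Hypothetical helper: reads `name` from Colab Secrets, else from os.environ.\"\"\"\n",
"    try:\n",
"        from google.colab import userdata  # importable only inside Colab\n",
"    except ImportError:\n",
"        # Outside Colab: fall back to a variable exported in your shell.\n",
"        value = os.environ.get(name)\n",
"        if value is None:\n",
"            raise KeyError(f\"Set the {name} environment variable before running this notebook.\")\n",
"        return value\n",
"    # Inside Colab: userdata.get raises an error if the secret is unavailable.\n",
"    return userdata.get(name)\n",
"\n",
"\n",
"# Portable replacement for the userdata.get calls in the cell above:\n",
"# os.environ[\"KAGGLE_USERNAME\"] = get_secret(\"KAGGLE_USERNAME\")\n",
"# os.environ[\"KAGGLE_KEY\"] = get_secret(\"KAGGLE_KEY\")"
]
},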
{
"cell_type": "markdown",
"metadata": {
"id": "Ea_56Zpa78Gu"
},
"source": [
"### Install dependencies\n",
"\n",
"Install Keras and KerasNLP."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AxPjbcnC79ck"
},
"outputs": [],
"source": [
"# Install Keras 3 last. See https://keras.io/getting_started/ for more details.\n",
"!pip install -q -U keras-nlp\n",
"!pip install -q -U keras"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a_QCPQLf8OU0"
},
"source": [
"### Load the model and create a chat helper to manage the conversation state"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2BmB5Zua8Vs0"
},
"outputs": [],
"source": [
"import re\n",
"\n",
"import keras\n",
"import keras_nlp\n",
"\n",
"# Run at half precision to fit in memory\n",
"keras.config.set_floatx(\"bfloat16\")\n",
"\n",
"gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_instruct_2b_en\")\n",
"gemma_lm.compile(sampler=\"top_k\")\n",
"\n",
"\n",
"class ChatState:\n",
"    \"\"\"\n",
"    Manages the conversation history for a turn-based chatbot.\n",
"    Follows the turn-based conversation guidelines for the Gemma family of models\n",
"    documented at https://ai.google.dev/gemma/docs/formatting\n",
"    \"\"\"\n",
"\n",
"    __START_TURN_USER__ = \"<start_of_turn>user\\n\"\n",
"    __START_TURN_MODEL__ = \"<start_of_turn>model\\n\"\n",
"    __END_TURN__ = \"<end_of_turn>\\n\"\n",
"\n",
"    def __init__(self, model, system=\"\"):\n",
"        \"\"\"\n",
"        Initializes the chat state.\n",
"\n",
"        Args:\n",
"            model: The language model to use for generating responses.\n",
"            system: (Optional) System instructions or bot description.\n",
"        \"\"\"\n",
"        self.model = model\n",
"        self.system = system\n",
"        self.history = []\n",
"\n",
"    def add_to_history_as_user(self, message):\n",
"        \"\"\"\n",
"        Adds a user message to the history with start/end turn markers.\n",
"        \"\"\"\n",
"        self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)\n",
"\n",
"    def add_to_history_as_model(self, message):\n",
"        \"\"\"\n",
"        Adds a model response to the history with start/end turn markers.\n",
"        The generated text may already end with <end_of_turn>, so that marker\n",
"        is stripped first and the turn is closed exactly once.\n",
"        \"\"\"\n",
"        message = message.rstrip().removesuffix(self.__END_TURN__.strip()).rstrip()\n",
"        self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)\n",
"\n",
"    def get_history(self):\n",
"        \"\"\"\n",
"        Returns the entire chat history as a single string.\n",
"        \"\"\"\n",
"        return \"\".join(self.history)\n",
"\n",
"    def get_full_prompt(self):\n",
"        \"\"\"\n",
"        Builds the prompt for the language model, including history and system description.\n",
"        \"\"\"\n",
"        prompt = self.get_history() + self.__START_TURN_MODEL__\n",
"        if self.system:\n",
"            prompt = self.system + \"\\n\" + prompt\n",
"        return prompt\n",
"\n",
"    def send_message(self, message):\n",
"        \"\"\"\n",
"        Handles sending a user message and getting a model response.\n",
"\n",
"        Args:\n",
"            message: The user's message.\n",
"\n",
"        Returns:\n",
"            The model's response.\n",
"        \"\"\"\n",
"        self.add_to_history_as_user(message)\n",
"        prompt = self.get_full_prompt()\n",
"        response = self.model.generate(prompt, max_length=4096)\n",
"        result = response.replace(prompt, \"\")  # Extract only the new response\n",
"        self.add_to_history_as_model(result)\n",
"        return result\n",
"\n",
"    def show_history(self):\n",
"        \"\"\"\n",
"        Prints each turn in the chat history.\n",
"        \"\"\"\n",
"        for h in self.history:\n",
"            print(h)\n",
"\n",
"\n",
"chat = ChatState(gemma_lm)"
]
},
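{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what Gemma actually receives, the next cell is a minimal, model-free sketch of the prompt layout from the Gemma formatting guide that `ChatState` follows. Each turn starts with `<start_of_turn>user` or `<start_of_turn>model` and closes with `<end_of_turn>`, and the prompt ends with an open model turn for Gemma to complete. The `format_gemma_prompt` function and the sample turns are hypothetical and not part of the original tutorial."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical standalone helper (not part of the original notebook) that\n",
"# reproduces the turn layout from https://ai.google.dev/gemma/docs/formatting\n",
"# so you can inspect a prompt without loading the model.\n",
"START_TURN_USER = \"<start_of_turn>user\\n\"\n",
"START_TURN_MODEL = \"<start_of_turn>model\\n\"\n",
"END_TURN = \"<end_of_turn>\\n\"\n",
"\n",
"\n",
"def format_gemma_prompt(turns, system=\"\"):\n",
"    \"\"\"Builds a prompt from (role, text) pairs, ending with an open model turn.\"\"\"\n",
"    parts = []\n",
"    for role, text in turns:\n",
"        start = START_TURN_USER if role == \"user\" else START_TURN_MODEL\n",
"        parts.append(start + text + END_TURN)\n",
"    prompt = \"\".join(parts) + START_TURN_MODEL\n",
"    if system:\n",
"        # Like ChatState, prepend optional instructions as plain text.\n",
"        prompt = system + \"\\n\" + prompt\n",
"    return prompt\n",
"\n",
"\n",
"print(format_gemma_prompt([\n",
"    (\"user\", \"Generate a random single word from animal.\"),\n",
"    (\"model\", \"Otter\"),\n",
"    (\"user\", \"Describe the word without saying it.\"),\n",
"]))"
]
},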
{
"cell_type": "markdown",
"metadata": {
"id": "_1jyCoRd8EwX"
},
"source": [
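"## Play the game\n",
"\n",
"Run the cell below and enter a theme (for example, `animal`). The model first generates a random word from that theme. The chat history is then cleared, and the model is asked to describe the word without saying it. Any occurrence of the word in a description is replaced with `XXXX`. After each description, type a guess. A correct guess (case-insensitive) ends the game, a wrong guess makes the model describe the word again, and typing `quit` reveals the answer."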
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "zoWDt87V83rZ"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Choose your theme: animal\n",
"Guess what I'm thinking.\n",
"Type \"quit\" if you want to quit.\n",
"A playful, furry swimmer. Known for its playful antics and clever use of tools. A member of the weasel-like family with a distinctive, waterproof coat. \n",
"<end_of_turn>\n",
"\n",
"> platypus\n",
"A creature that spends its days by a river or lake, often seen splashing and diving with its distinctive, thick fur. It's known for its playful demeanor and ability to hold a surprisingly large amount of water in its paws. \n",
"<end_of_turn>\n",
"\n",
"> beaver\n",
"A small, aquatic mammal with a playful, curious nature, often spotted near water, known for its distinctive, waterproof fur and love for swimming. \n",
"<end_of_turn>\n",
"\n",
"> otter\n",
"Correct!\n"
]
}
],
"source": [
"theme = input(\"Choose your theme: \")\n",
"setup_message = f\"Generate a random single word from {theme}.\"\n",
"\n",
"chat.history.clear()\n",
"answer = chat.send_message(setup_message).split()[0]\n",
"answer = re.sub(r\"\\W+\", \"\", answer)  # keep only letters, digits and '_' (drops punctuation such as '**' or '.')\n",
"chat.history.clear()\n",
"cmd_exit = \"quit\"\n",
"question = f'Describe the word \"{answer}\" without saying it.'\n",
"\n",
"resp = \"\"\n",
"while resp.lower() != answer.lower() and resp != cmd_exit:\n",
"    text = chat.send_message(question)\n",
"    if resp == \"\":\n",
"        print(f'Guess what I\\'m thinking.\\nType \"{cmd_exit}\" if you want to quit.')\n",
"    remove_answer = re.compile(re.escape(answer), re.IGNORECASE)\n",
"    text = remove_answer.sub(\"XXXX\", text)  # hide the answer if the model says it\n",
"    print(text)\n",
"    resp = input(\"\\n> \")\n",
"\n",
"if resp == cmd_exit:\n",
"    print(f\"The answer was {answer}.\\n\")\n",
"else:\n",
"    print(\"Correct!\")"
]
}
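,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The game cell uses the first whitespace-separated token of the model's reply as the answer. If that token consists only of punctuation (for example, a stray `**`), the cleaned answer is an empty string, and the loop ends immediately with \"Correct!\". The next cell is a minimal sketch, not part of the original tutorial, of two hypothetical helpers, `extract_answer` and `mask_answer`, that factor out this answer handling so you can test it without loading Gemma. `extract_answer` skips tokens without letters and returns `None` instead of an empty string. It still cannot tell a preamble such as \"Sure\" from the intended word, so checking for `None` and regenerating is left to you."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"\n",
"def extract_answer(model_output):\n",
"    \"\"\"Hypothetical helper: returns the letters of the first token that has any, else None.\"\"\"\n",
"    for token in model_output.split():\n",
"        # [\\W\\d_] matches everything except letters, so only letters remain.\n",
"        word = re.sub(r\"[\\W\\d_]\", \"\", token)\n",
"        if word:\n",
"            return word\n",
"    return None\n",
"\n",
"\n",
"def mask_answer(text, answer, placeholder=\"XXXX\"):\n",
"    \"\"\"Hypothetical helper: hides every case-insensitive occurrence of `answer`.\"\"\"\n",
"    pattern = re.compile(re.escape(answer), re.IGNORECASE)\n",
"    return pattern.sub(placeholder, text)\n",
"\n",
"\n",
"print(extract_answer(\"** Otter**\"))  # Otter\n",
"print(extract_answer(\"...\"))  # None\n",
"print(mask_answer(\"Otters love water. An otter floats.\", \"otter\"))  # XXXXs love water. An XXXX floats."
]
}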
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "[Gemma_2]Guess_the_word.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}